{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成 ModelProto\n",
    "\n",
    "此示例展示了如何使用 onnxscript 定义 ONNX 模型。onnxscript 的行为类似于编译器：它将使用 `@script()` 装饰的 Python 函数转换为 ONNX 模型。\n",
    "\n",
    "首先，我们在 onnxscript 中定义平方损失函数的实现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import onnx\n",
    "from onnxruntime import InferenceSession\n",
    "\n",
    "from onnxscript import FLOAT, script\n",
    "from onnxscript import opset15 as op\n",
    "\n",
    "\n",
    "@script()\n",
    "def square_loss(X: FLOAT[\"N\", 1], Y: FLOAT[\"N\", 1]) -> FLOAT[1, 1]:  # noqa: F821\n",
    "    # Element-wise residual between the two inputs.\n",
    "    diff = op.Sub(X, Y)\n",
    "    # Square the residuals and sum them; keepdims=1 keeps the reduced axes,\n",
    "    # which matches the declared FLOAT[1, 1] return type.\n",
    "    return op.ReduceSum(op.Mul(diff, diff), keepdims=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以按照以下方式将其转换为模型（ONNX ModelProto）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = square_loss.to_model_proto()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以文本形式打印模型，查看生成的计算图："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<\n",
      "   ir_version: 8,\n",
      "   opset_import: [\"\" : 15]\n",
      ">\n",
      "square_loss (float[N,1] X, float[N,1] Y) => (float[1,1] return_val) {\n",
      "   diff = Sub (X, Y)\n",
      "   tmp = Mul (diff, diff)\n",
      "   return_val = ReduceSum <keepdims: int = 1> (tmp)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(onnx.printer.to_text(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以使用标准的 ONNX API 对模型进行形状推断和类型检查。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = onnx.shape_inference.infer_shapes(model)\n",
    "onnx.checker.check_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，我们可以使用标准的 onnxruntime API 通过 onnxruntime 计算该模型的输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.13999999 [array([[0.13999999]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "sess = InferenceSession(model.SerializeToString(), providers=(\"CPUExecutionProvider\",))\n",
    "\n",
    "# Column vectors of shape (3, 1), matching the model's FLOAT[\"N\", 1] inputs.\n",
    "X = np.array([[0, 1, 2]], dtype=np.float32).T\n",
    "Y = np.array([[0.1, 1.2, 2.3]], dtype=np.float32).T\n",
    "\n",
    "# Passing None as the output names requests every model output; the result is a list of arrays.\n",
    "got = sess.run(None, {\"X\": X, \"Y\": Y})\n",
    "# NumPy reference for the same squared loss.\n",
    "expected = ((X - Y) ** 2).sum()\n",
    "\n",
    "# Verify the onnxruntime result against the reference instead of only eyeballing the printout.\n",
    "assert got[0].shape == (1, 1)\n",
    "np.testing.assert_allclose(got[0], expected, rtol=1e-5)\n",
    "\n",
    "print(expected, got)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
